"""AWS cost, reserved instance / savings plan expiry and subnet IP availability metrics."""
import collections
import datetime
import enum
import itertools
import os
import re

import boto3
from cki_lib import misc
from cki_lib import yaml
from cki_lib.cronjob import CronJob
from cki_lib.logger import get_logger
import dateutil.parser
import prometheus_client

# Module-wide logger.
LOGGER = get_logger(__name__)

# Exporter configuration, read via cki_lib's yaml.load from the AWS_CONFIG
# environment variable (passed as `contents`) and/or the path in
# AWS_CONFIG_PATH (passed as `file_path`); any falsy result (e.g. nothing
# configured) becomes an empty dict.
AWS_CONFIG = yaml.load(
    contents=os.environ.get('AWS_CONFIG'),
    file_path=os.environ.get('AWS_CONFIG_PATH'),
) or {}


# Cost query descriptor: `type` selects the query kind ('daily' or 'grouped'),
# `grouping` is the Cost Explorer GroupBy entry and `clean_fn` turns a raw
# group key into the label value that gets exported.
_Cost = collections.namedtuple('_Cost', 'type grouping clean_fn')

# Key normalization patterns, compiled once instead of on every call.
_TAG_KEY_PREFIX = re.compile(r'.*\$')
_USAGE_TYPE_NOISE = re.compile('BoxUsage:|Usage|-ByteHrs')
_OPERATION_SUFFIX = re.compile(r':SV\d\d\d')

# Short names for the most common SERVICE dimension values.
_SERVICE_ALIASES = {
    'Amazon Elastic Compute Cloud - Compute': 'EC2',
    'Amazon Simple Storage Service': 'S3',
    'EC2 - Other': 'EC2-Other',
}


def _strip_tag_key(key):
    """Return the part of a ``TagKey$value`` group key after the last ``$``."""
    return _TAG_KEY_PREFIX.sub('', key)


class Cost(_Cost, enum.Enum):
    """Cost type enumeration.

    Every member is a ``_Cost`` tuple describing one Cost Explorer query.
    DAILY requests ungrouped daily totals; all other members group the costs
    by a tag or dimension and normalize the resulting group keys with
    ``clean_fn``.
    """

    DAILY = _Cost('daily', None, None)
    # NOTE(review): untagged resources presumably arrive as 'ServiceComponent$'
    # and are exported with an empty key, whereas OWNER maps the equivalent
    # case to 'Unknown' -- confirm this difference is intended.
    COMPONENT = _Cost(
        'grouped',
        {'Type': 'TAG', 'Key': 'ServiceComponent'},
        _strip_tag_key)
    OWNER = _Cost(
        'grouped',
        {'Type': 'TAG', 'Key': 'ServiceOwner'},
        lambda key: 'Unknown' if key == 'ServiceOwner$' else _strip_tag_key(key))
    SERVICE = _Cost(
        'grouped',
        {'Type': 'DIMENSION', 'Key': 'SERVICE'},
        lambda key: _SERVICE_ALIASES.get(key, key))
    USAGE_TYPE = _Cost(
        'grouped',
        {'Type': 'DIMENSION', 'Key': 'USAGE_TYPE'},
        lambda key: ('EBS' if key == 'EBS:VolumeUsage.gp2'
                     else _USAGE_TYPE_NOISE.sub('', key)))
    INSTANCE_TYPE = _Cost(
        'grouped',
        {'Type': 'DIMENSION', 'Key': 'INSTANCE_TYPE'},
        lambda key: key)
    OPERATION = _Cost(
        'grouped',
        {'Type': 'DIMENSION', 'Key': 'OPERATION'},
        lambda key: _OPERATION_SUFFIX.sub('', key))


class AwsCostExplorer:
    """Interface with the AWS Cost Explorer API.

    All queries report the ``NetAmortizedCost`` metric at ``DAILY``
    granularity for a window of ``days`` days ending at the current date
    returned by ``misc.now_tz_utc()`` (the Cost Explorer end date is
    exclusive).
    """

    def __init__(self, days=7, filters=None):
        """Create a new Cost Explorer client.

        Args:
            days: Length of the query window in days.
            filters: Optional Cost Explorer ``Filter`` expression added to
                every query; falsy values disable filtering.
        """
        self.session = boto3.Session()
        self.client = self.session.client('ce')

        end = misc.now_tz_utc().date()
        start = end - datetime.timedelta(days=days)
        one_day = end - datetime.timedelta(days=1)
        self.days = days
        self.end = end.strftime('%Y-%m-%d')
        self.start = start.strftime('%Y-%m-%d')
        self.one_day = one_day.strftime('%Y-%m-%d')

        self.common_args = {
            'Granularity': 'DAILY',
            'Metrics': ['NetAmortizedCost'],
        }
        if filters:
            self.common_args['Filter'] = filters

    def _results_by_time(self, **kwargs):
        """Yield every ``ResultsByTime`` entry of a get_cost_and_usage query.

        Cost Explorer splits large responses into pages and returns a
        ``NextPageToken`` while more data is available; all pages are
        fetched so that no days or groups are silently dropped.

        Args:
            **kwargs: Query arguments merged over ``self.common_args``.
        """
        request = {**self.common_args, **kwargs}
        while True:
            response = self.client.get_cost_and_usage(**request)
            yield from response['ResultsByTime']
            token = response.get('NextPageToken')
            if not token:
                return
            request['NextPageToken'] = token

    def costs(self, cost_type: 'Cost'):
        """Return a list with cost statistics for the given Cost member."""
        if cost_type.type == 'daily':
            return self.daily_costs()
        return self.grouped_costs(
            cost_type.grouping, cost_type.clean_fn)

    def daily_costs(self):
        """Return ``(MM/DD, amount)`` tuples for every day, newest first."""
        results = sorted(
            self._results_by_time(TimePeriod={'Start': self.start, 'End': self.end}),
            key=lambda r: r['TimePeriod']['Start'],
            reverse=True)
        return [(r['TimePeriod']['Start'][5:].replace('-', '/'),
                 float(r['Total']['NetAmortizedCost']['Amount']))
                for r in results]

    def total_costs(self):
        """Return the total cost of the last day of the window."""
        results = list(self._results_by_time(
            TimePeriod={'Start': self.one_day, 'End': self.end}))
        return float(results[0]['Total']['NetAmortizedCost']['Amount'])

    def grouped_costs(self, groups, clean_fn=None):
        """Return grouped costs averaged per day, most expensive first.

        Args:
            groups: A single Cost Explorer ``GroupBy`` entry.
            clean_fn: Optional callable mapping raw group keys to the
                exported key; keys that map to the same value are summed.

        Returns:
            A list of ``(key, average_daily_amount)`` tuples sorted by
            descending amount.
        """
        amounts = collections.defaultdict(float)
        results = self._results_by_time(
            TimePeriod={'Start': self.start, 'End': self.end},
            GroupBy=[groups])
        for result in results:
            for group in result['Groups']:
                key = group['Keys'][0]
                if clean_fn:
                    key = clean_fn(key)
                amounts[key] += float(group['Metrics']['NetAmortizedCost']['Amount']) / self.days
        return sorted(amounts.items(), key=lambda r: r[1], reverse=True)

    @staticmethod
    def summarize(amounts, item_cutoff, max_items):
        """Format the cost list into a simple string.

        Args:
            amounts: Iterable of ``(key, amount)`` tuples.
            item_cutoff: If truthy, drop entries with an amount below it.
            max_items: If truthy, keep at most this many entries.

        Returns:
            Space-separated ``key=$amount`` pairs with two decimals.
        """
        if item_cutoff:
            amounts = (r for r in amounts if r[1] >= item_cutoff)
        if max_items:
            amounts = itertools.islice(amounts, max_items)
        return ' '.join(f'{r[0]}=${r[1]:.2f}' for r in amounts)


class AwsMetricsDaily(CronJob):
    """Calculate daily AWS metrics.

    Exports Cost Explorer costs per configured cost group and grouped cost
    type, plus the end dates of EC2 reserved instances and savings plans.
    Each update replaces the complete label set of its metric, so series for
    keys that disappeared (or whose `state` label changed) are not exported
    with stale values forever.
    """

    # once a day slightly after UTC midnight which should return the data of the previous UTC day
    schedule = '10 0 * * *'

    metric_cost = prometheus_client.Gauge(
        'cki_aws_cost',
        'grouped AWS costs',
        ['group', 'type', 'key']
    )
    metric_reserved_instance_end = prometheus_client.Gauge(
        'cki_aws_reserved_instance_end',
        'end date for reserved instances',
        ['reservedinstancesid', 'instancecount', 'instancetype', 'state']
    )
    metric_savings_plan_end = prometheus_client.Gauge(
        'cki_aws_savings_plan_end',
        'end date for savings plans',
        ['savingsplanid', 'savingsplantype', 'commitment', 'state']
    )

    def run(self, **_):
        """Update the metrics."""
        self.update_metric_cost()
        self.update_metric_reserved_instance_expiry()
        self.update_metric_savings_plan_expiry()

    def update_metric_cost(self) -> None:
        """Update metric_cost.

        For every cost group in the `cost/groups` config (default: a single
        unfiltered `all` group), query one day of costs for each grouped Cost
        member and export them as `(group, type, key)` series.
        """
        samples = []
        groups = misc.get_nested_key(AWS_CONFIG, 'cost/groups', {'all': {}})
        for name, filters in groups.items():
            explorer = AwsCostExplorer(days=1, filters=filters)
            for cost_type, cost_group in Cost.__members__.items():
                if cost_type == 'DAILY':
                    continue  # ungrouped totals have no key to export
                for key, amount in explorer.costs(cost_group):
                    samples.append(((name, cost_type.lower(), key), amount))
        # Only replace the exported series after all queries succeeded; keys
        # without costs in this run must not keep their previous value.
        self.metric_cost.clear()
        for labels, amount in samples:
            self.metric_cost.labels(*labels).set(amount)

    def update_metric_reserved_instance_expiry(self) -> None:
        """Update metric_reserved_instance_end with the end time of every reservation."""
        ec2_client = boto3.Session().client('ec2')
        reservations = ec2_client.describe_reserved_instances()['ReservedInstances']
        # Drop series from earlier runs, e.g. an old `state` label of a
        # reservation that has since changed state.
        self.metric_reserved_instance_end.clear()
        for reservation in reservations:
            self.metric_reserved_instance_end.labels(
                reservation['ReservedInstancesId'],
                reservation['InstanceCount'],
                reservation['InstanceType'],
                reservation['State'],
            ).set(reservation['End'].timestamp())

    def update_metric_savings_plan_expiry(self) -> None:
        """Update metric_savings_plan_end with the end time of every savings plan."""
        savingsplans_client = boto3.Session().client('savingsplans')
        plans = []
        request = {}
        # DescribeSavingsPlans returns a `nextToken` while more results exist.
        while True:
            response = savingsplans_client.describe_savings_plans(**request)
            plans.extend(response['savingsPlans'])
            token = response.get('nextToken')
            if not token:
                break
            request['nextToken'] = token
        # Drop series from earlier runs, e.g. an old `state` label of a plan
        # that has since changed state.
        self.metric_savings_plan_end.clear()
        for plan in plans:
            self.metric_savings_plan_end.labels(
                plan['savingsPlanId'],
                plan['savingsPlanType'],
                plan['commitment'],
                plan['state'],
            ).set(dateutil.parser.parse(plan['end']).timestamp())


class AwsMetricsMinutely(CronJob):
    """Calculate frequently refreshed AWS metrics.

    Exports the number of available IP addresses per subnet. Each update
    replaces the complete label set, so deleted or renamed subnets do not
    keep exporting their last value.
    """

    schedule = '*/5 * * * *'  # every 5 minutes

    metric_subnets = prometheus_client.Gauge(
        'cki_aws_subnet_available_ip_address_count',
        'available IP addresses',
        ['vpcid', 'subnetid', 'availabilityzone', 'cidrblock', 'name']
    )

    def run(self, **_) -> None:
        """Update the metrics."""
        self.update_metric_subnets()

    def update_metric_subnets(self) -> None:
        """Update metric_subnets with the available IP address count of every subnet."""
        ec2_client = boto3.Session().client('ec2')
        subnets = ec2_client.describe_subnets()['Subnets']
        # Replace the series only after the API call succeeded.
        self.metric_subnets.clear()
        for subnet in subnets:
            # The subnet name is the value of its `Name` tag, if any.
            name = next((t['Value'] for t in subnet.get('Tags', []) if t['Key'] == 'Name'), '')
            self.metric_subnets.labels(
                subnet['VpcId'],
                subnet['SubnetId'],
                subnet['AvailabilityZone'],
                subnet['CidrBlock'],
                name,
            ).set(subnet['AvailableIpAddressCount'])
